Improve cancel generation with robust UI state management and orchestrator pattern
app.py
CHANGED
@@ -334,7 +334,7 @@ def load_pipeline(model_name):
         model=repo,
         tokenizer=tokenizer,
         trust_remote_code=True,
-        torch_dtype=dtype,
+        dtype=dtype, # Use `dtype` instead of deprecated `torch_dtype`
         device_map="auto",
         use_cache=True, # Enable past-key-value caching
         token=access_token)
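The hunk above only swaps the keyword passed to the text-generation pipeline. As a rough illustration, a loader following the new convention might look like the sketch below; `repo` and `access_token` mirror names from the diff, while the wrapper function and the dtype choice are assumptions, and `dtype=` is only accepted on recent transformers releases that deprecate `torch_dtype`.

```python
# Hypothetical loader sketch; only the kwargs visible in the diff are taken from app.py.
import torch
from transformers import AutoTokenizer, pipeline

def load_pipeline_sketch(repo: str, access_token: str | None = None):
    tokenizer = AutoTokenizer.from_pretrained(repo, token=access_token)
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32  # illustrative choice
    return pipeline(
        "text-generation",
        model=repo,
        tokenizer=tokenizer,
        trust_remote_code=True,
        dtype=dtype,        # replaces the deprecated `torch_dtype` kwarg
        device_map="auto",
        use_cache=True,     # enable past-key-value caching
        token=access_token,
    )
```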
@@ -509,12 +509,14 @@ def chat_response(user_msg, chat_history, system_prompt,
         thought_buf = ''
         answer_buf = ''
         in_thought = False
+        assistant_message_started = False
 
         # Stream tokens
         for chunk in streamer:
             # Check for cancellation signal
             if cancel_event.is_set():
-                history[-1]['
+                if assistant_message_started and history and history[-1]['role'] == 'assistant':
+                    history[-1]['content'] += " [Generation Canceled]"
                 yield history, debug
                 break
 
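The cancellation check assumes `cancel_event` is a module-level `threading.Event` that the UI sets while the streaming loop polls it between chunks. A minimal sketch of that pattern with `TextIteratorStreamer` follows; the model, tokenizer, and generation settings are placeholders, and only the control flow mirrors `chat_response`.

```python
# Sketch of the cancel_event polling pattern used in chat_response.
# Model/tokenizer are supplied by the caller; generation runs in a worker
# thread so the main thread can consume the streamer and check the flag.
import threading
from transformers import TextIteratorStreamer

cancel_event = threading.Event()

def stream_reply(model, tokenizer, prompt, max_new_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    worker = threading.Thread(
        target=model.generate,
        kwargs=dict(**inputs, max_new_tokens=max_new_tokens, streamer=streamer),
    )
    worker.start()

    answer = ""
    for text in streamer:
        if cancel_event.is_set():
            answer += " [Generation Canceled]"  # annotate the partial reply, as the diff does
            yield answer
            break
        answer += text
        yield answer
```

Breaking out of the loop only stops the UI stream; the underlying `generate()` call is shut down separately, which is what the `cancels=[submit_event]` wiring later in this diff is for.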
@@ -523,21 +525,14 @@ def chat_response(user_msg, chat_history, system_prompt,
             # Detect start of thinking
             if not in_thought and '<think>' in text:
                 in_thought = True
-
-                history.append({
-                    'role': 'assistant',
-                    'content': '',
-                    'metadata': {'title': '💭 Thought'}
-                })
-                # Capture after opening tag
+                history.append({'role': 'assistant', 'content': '', 'metadata': {'title': '💭 Thought'}})
+                assistant_message_started = True
                 after = text.split('<think>', 1)[1]
                 thought_buf += after
-                # If closing tag in same chunk
                 if '</think>' in thought_buf:
                     before, after2 = thought_buf.split('</think>', 1)
                     history[-1]['content'] = before.strip()
                     in_thought = False
-                    # Start answer buffer
                     answer_buf = after2
                     history.append({'role': 'assistant', 'content': answer_buf})
                 else:
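The single-line `history.append(...)` above relies on the Gradio Chatbot "messages" format, where an assistant message carrying a `metadata` title is rendered as a collapsible bubble. A toy, self-contained example of that convention (the replies are fabricated; only the message structure matters):

```python
# Toy demo of the 'metadata': {'title': ...} convention used for the thought bubble.
import gradio as gr

def fake_reply(user_msg, history):
    history = history + [{"role": "user", "content": user_msg}]
    history.append({"role": "assistant",
                    "content": "Reasoning about the question...",
                    "metadata": {"title": "💭 Thought"}})  # collapsible in the Chatbot
    history.append({"role": "assistant", "content": f"You said: {user_msg}"})
    return history, ""

with gr.Blocks() as demo:
    chat = gr.Chatbot(type="messages")
    txt = gr.Textbox(label="Message")
    txt.submit(fake_reply, inputs=[txt, chat], outputs=[chat, txt])

if __name__ == "__main__":
    demo.launch()
```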
@@ -545,14 +540,12 @@ def chat_response(user_msg, chat_history, system_prompt,
                     yield history, debug
                     continue
 
-            # Continue thought streaming
             if in_thought:
                 thought_buf += text
                 if '</think>' in thought_buf:
                     before, after2 = thought_buf.split('</think>', 1)
                     history[-1]['content'] = before.strip()
                     in_thought = False
-                    # Start answer buffer
                     answer_buf = after2
                     history.append({'role': 'assistant', 'content': answer_buf})
                 else:
@@ -561,8 +554,10 @@ def chat_response(user_msg, chat_history, system_prompt,
                     continue
 
             # Stream answer
-            if not answer_buf:
+            if not answer_buf and not assistant_message_started:
                 history.append({'role': 'assistant', 'content': ''})
+                assistant_message_started = True
+
             answer_buf += text
             history[-1]['content'] = answer_buf
             yield history, debug
@@ -573,7 +568,6 @@ def chat_response(user_msg, chat_history, system_prompt,
         history.append({'role': 'assistant', 'content': f"Error: {e}"})
         yield history, debug
     finally:
-        # Final cleanup
         gc.collect()
 
 
@@ -583,21 +577,14 @@ def update_default_prompt(enable_search):
 def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout):
     """Calculate and format the estimated GPU duration for current settings."""
     try:
-
-        dummy_msg = ""
-        dummy_history = []
-        dummy_system_prompt = ""
-
+        dummy_msg, dummy_history, dummy_system_prompt = "", [], ""
         duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
                                 enable_search, max_results, max_chars, model_name,
                                 max_tokens, 0.7, 40, 0.9, 1.2, search_timeout)
-
         model_size = MODELS[model_name].get("params_b", 4.0)
-
-
-        return f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n" \
-               f"📊 **Model Size:** {model_size:.1f}B parameters\n" \
-               f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}"
+        return (f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
+                f"📊 **Model Size:** {model_size:.1f}B parameters\n"
+                f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}")
     except Exception as e:
         return f"⚠️ Error calculating estimate: {e}"
 
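`update_duration_estimate` reuses `get_duration`, which on a ZeroGPU Space is typically the same callable handed to `spaces.GPU(duration=...)` so the quota request and the UI estimate stay in sync. The sketch below shows that wiring under that assumption; the estimation formula and parameter weights are invented for illustration, and only the argument list matches the diff.

```python
# Assumed wiring between get_duration and ZeroGPU's dynamic duration feature.
# The arithmetic is illustrative; app.py's real estimator is not shown in the diff.
import spaces

def get_duration(user_msg, chat_history, system_prompt,
                 enable_search, max_results, max_chars, model_name,
                 max_tokens, temperature, top_k, top_p, repetition_penalty,
                 search_timeout):
    base = 10.0                                   # warm-up / model load allowance
    decode = max_tokens / 50.0                    # crude tokens-per-second budget
    search = search_timeout if enable_search else 0.0
    return base + decode + search                 # seconds of GPU time to request

@spaces.GPU(duration=get_duration)                # ZeroGPU calls it with the same args
def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars, model_name,
                  max_tokens, temperature, top_k, top_p, repetition_penalty,
                  search_timeout):
    ...  # streaming generation as in app.py
```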
@@ -613,10 +600,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
             search_chk = gr.Checkbox(label="Enable Web Search", value=False)
             sys_prompt = gr.Textbox(label="System Prompt", lines=3, value=update_default_prompt(search_chk.value))
 
-
-            duration_display = gr.Markdown(value=update_duration_estimate(
-                "Qwen3-1.7B", False, 4, 50, 1024, 5.0
-            ))
+            duration_display = gr.Markdown(value=update_duration_estimate("Qwen3-1.7B", False, 4, 50, 1024, 5.0))
 
             gr.Markdown("### Generation Parameters")
             max_tok = gr.Slider(64, 16384, value=1024, step=32, label="Max Tokens")
@@ -641,58 +625,75 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
 
     # Group all inputs for cleaner event handling
     chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st]
-
-
-
+    # Group all UI components that change state
+    interactive_components = [txt, submit_btn, cancel_btn, chat, dbg]
+
+    def submit_and_manage_ui(user_msg, chat_history, *args):
+        """
+        An orchestrator function that manages the UI state and calls the backend chat function.
+        It uses a try...finally block to ensure the UI is always reset.
+        """
+        # Immediately update UI to a "generating" state
         yield {
+            # Add the user's message to the chat and a placeholder for the response
+            chat: chat_history + [[user_msg, None]],
+            txt: gr.update(value="", interactive=False),
             submit_btn: gr.update(interactive=False),
             cancel_btn: gr.update(visible=True),
-            txt: gr.update(interactive=False, value=""), # Clear textbox and disable
         }
-
-
+
+        try:
+            # Package the arguments for the backend function
+            backend_args = [user_msg, chat_history] + list(args)
+            # Stream the response from the backend
+            for response_chunk in chat_response(*backend_args):
+                yield {
+                    chat: response_chunk[0],
+                    dbg: response_chunk[1],
+                }
+        except Exception as e:
+            print(f"An error occurred during generation: {e}")
+        finally:
+            # Always reset the UI to an "idle" state, regardless of completion or cancellation
+            print("Resetting UI state.")
             yield {
+                txt: gr.update(interactive=True),
+                submit_btn: gr.update(interactive=True),
+                cancel_btn: gr.update(visible=False),
             }
 
-    def reset_ui_after_generation():
-        # Update UI back to "idle" state
-        return {
-            submit_btn: gr.update(interactive=True),
-            cancel_btn: gr.update(visible=False),
-            txt: gr.update(interactive=True), # Re-enable textbox
-        }
-
     def set_cancel_flag():
+        """Called by the cancel button, sets the global event."""
         cancel_event.set()
         print("Cancellation signal sent.")
 
-    #
+    # Event for submitting text via Enter key
     submit_event = txt.submit(
-        fn=
+        fn=submit_and_manage_ui,
         inputs=chat_inputs,
-        outputs=
-    )
-
+        outputs=interactive_components,
+    )
+
+    # Event for submitting text via the "Submit" button
     submit_btn.click(
-        fn=
+        fn=submit_and_manage_ui,
        inputs=chat_inputs,
-        outputs=
-    )
+        outputs=interactive_components,
+    )
 
-    #
+    # Event for the "Cancel" button. It calls the flag-setting function
+    # and, crucially, cancels the long-running submit_event.
     cancel_btn.click(
         fn=set_cancel_flag,
-        cancels=[submit_event]
+        cancels=[submit_event]
     )
-
-    #
+
+    # Listeners for updating the duration estimate
     duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st]
     for component in duration_inputs:
         component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)
 
-    # Other event listeners
+    # Other minor event listeners
     search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
     clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
 
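Taken together, the final hunk replaces the separate `reset_ui_after_generation` helper with a single orchestrator generator: it yields dictionaries of component updates, streams the backend output, and uses `try...finally` so the controls return to their idle state no matter how generation ends. A condensed, self-contained sketch of the same pattern, with a dummy backend standing in for `chat_response`:

```python
# Condensed sketch of the orchestrator pattern: one generator event handler
# yields dicts of component updates and a try/finally restores the idle state.
# dummy_backend stands in for chat_response; component names mirror the diff.
import time
import threading
import gradio as gr

cancel_event = threading.Event()

def dummy_backend(user_msg, history):
    text = ""
    for word in f"Echoing: {user_msg}".split():
        if cancel_event.is_set():
            text += " [Generation Canceled]"
            yield history + [{"role": "assistant", "content": text}]
            return
        text += word + " "
        time.sleep(0.2)
        yield history + [{"role": "assistant", "content": text}]

def submit_and_manage_ui(user_msg, history):
    cancel_event.clear()                      # fresh flag for every run
    history = history + [{"role": "user", "content": user_msg}]
    # "Generating" state: lock the inputs, reveal Cancel
    yield {chat: history,
           txt: gr.update(value="", interactive=False),
           submit_btn: gr.update(interactive=False),
           cancel_btn: gr.update(visible=True)}
    try:
        for updated in dummy_backend(user_msg, history):
            yield {chat: updated}
    finally:
        # Always return to the "idle" state, even after cancellation
        yield {txt: gr.update(interactive=True),
               submit_btn: gr.update(interactive=True),
               cancel_btn: gr.update(visible=False)}

with gr.Blocks() as demo:
    chat = gr.Chatbot(type="messages")
    txt = gr.Textbox(label="Message")
    submit_btn = gr.Button("Submit")
    cancel_btn = gr.Button("Cancel", visible=False)

    ui_outputs = [chat, txt, submit_btn, cancel_btn]
    submit_event = txt.submit(submit_and_manage_ui, inputs=[txt, chat], outputs=ui_outputs)
    submit_btn.click(submit_and_manage_ui, inputs=[txt, chat], outputs=ui_outputs)
    # Set the flag for the streaming loop *and* abort the running event
    cancel_btn.click(fn=cancel_event.set, cancels=[submit_event])

if __name__ == "__main__":
    demo.launch()
```

The Cancel button does two things at once: `cancel_event.set()` lets the streaming loop exit cleanly, while `cancels=[submit_event]` tells Gradio to abort the running event itself.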