Luigi committed
Commit 9ac7f36
Parent: a94befb

Improve generation cancellation with robust UI state management and an orchestrator pattern
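In brief: submission is now handled by a single orchestrator generator that first yields a "generating" UI state (textbox and submit disabled, cancel button visible), then streams the backend chat_response, and relies on try/finally so the idle state is always restored, even when the run is stopped via cancel_btn.click(..., cancels=[submit_event]). A condensed sketch of that pattern, abridged from the new submit_and_manage_ui in the diff below (chat-history seeding and error handling omitted):

    def submit_and_manage_ui(user_msg, chat_history, *args):
        # Enter the "generating" state: lock inputs, show the cancel button
        yield {txt: gr.update(value="", interactive=False),
               submit_btn: gr.update(interactive=False),
               cancel_btn: gr.update(visible=True)}
        try:
            # Stream (history, debug) pairs from the backend generator
            for response_chunk in chat_response(user_msg, chat_history, *args):
                yield {chat: response_chunk[0], dbg: response_chunk[1]}
        finally:
            # Always restore the idle state, even if the run was cancelled
            yield {txt: gr.update(interactive=True),
                   submit_btn: gr.update(interactive=True),
                   cancel_btn: gr.update(visible=False)}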

Files changed (1):
    app.py (+60, -59)
app.py CHANGED
@@ -334,7 +334,7 @@ def load_pipeline(model_name):
         model=repo,
         tokenizer=tokenizer,
         trust_remote_code=True,
-        torch_dtype=dtype,
+        dtype=dtype,  # Use `dtype` instead of deprecated `torch_dtype`
         device_map="auto",
         use_cache=True,  # Enable past-key-value caching
         token=access_token)
@@ -509,12 +509,14 @@ def chat_response(user_msg, chat_history, system_prompt,
         thought_buf = ''
         answer_buf = ''
         in_thought = False
+        assistant_message_started = False

         # Stream tokens
         for chunk in streamer:
             # Check for cancellation signal
             if cancel_event.is_set():
-                history[-1]['content'] += " [Generation Canceled]"
+                if assistant_message_started and history and history[-1]['role'] == 'assistant':
+                    history[-1]['content'] += " [Generation Canceled]"
                 yield history, debug
                 break

@@ -523,21 +525,14 @@ def chat_response(user_msg, chat_history, system_prompt,
             # Detect start of thinking
             if not in_thought and '<think>' in text:
                 in_thought = True
-                # Insert thought placeholder
-                history.append({
-                    'role': 'assistant',
-                    'content': '',
-                    'metadata': {'title': '💭 Thought'}
-                })
-                # Capture after opening tag
+                history.append({'role': 'assistant', 'content': '', 'metadata': {'title': '💭 Thought'}})
+                assistant_message_started = True
                 after = text.split('<think>', 1)[1]
                 thought_buf += after
-                # If closing tag in same chunk
                 if '</think>' in thought_buf:
                     before, after2 = thought_buf.split('</think>', 1)
                     history[-1]['content'] = before.strip()
                     in_thought = False
-                    # Start answer buffer
                     answer_buf = after2
                     history.append({'role': 'assistant', 'content': answer_buf})
                 else:
@@ -545,14 +540,12 @@ def chat_response(user_msg, chat_history, system_prompt,
                     yield history, debug
                     continue

-            # Continue thought streaming
             if in_thought:
                 thought_buf += text
                 if '</think>' in thought_buf:
                     before, after2 = thought_buf.split('</think>', 1)
                     history[-1]['content'] = before.strip()
                     in_thought = False
-                    # Start answer buffer
                     answer_buf = after2
                     history.append({'role': 'assistant', 'content': answer_buf})
                 else:
@@ -561,8 +554,10 @@ def chat_response(user_msg, chat_history, system_prompt,
                     continue

             # Stream answer
-            if not answer_buf:
+            if not answer_buf and not assistant_message_started:
                 history.append({'role': 'assistant', 'content': ''})
+                assistant_message_started = True
+
             answer_buf += text
             history[-1]['content'] = answer_buf
             yield history, debug
@@ -573,7 +568,6 @@ def chat_response(user_msg, chat_history, system_prompt,
         history.append({'role': 'assistant', 'content': f"Error: {e}"})
         yield history, debug
     finally:
-        # Final cleanup
         gc.collect()


@@ -583,21 +577,14 @@ def update_default_prompt(enable_search):
 def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout):
     """Calculate and format the estimated GPU duration for current settings."""
     try:
-        # Create dummy values for the other parameters that get_duration expects
-        dummy_msg = ""
-        dummy_history = []
-        dummy_system_prompt = ""
-
+        dummy_msg, dummy_history, dummy_system_prompt = "", [], ""
         duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
                                 enable_search, max_results, max_chars, model_name,
                                 max_tokens, 0.7, 40, 0.9, 1.2, search_timeout)
-
         model_size = MODELS[model_name].get("params_b", 4.0)
-        use_aot = model_size >= 2
-
-        return f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n" \
-               f"📊 **Model Size:** {model_size:.1f}B parameters\n" \
-               f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}"
+        return (f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
+                f"📊 **Model Size:** {model_size:.1f}B parameters\n"
+                f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}")
     except Exception as e:
         return f"⚠️ Error calculating estimate: {e}"

@@ -613,10 +600,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
             search_chk = gr.Checkbox(label="Enable Web Search", value=False)
             sys_prompt = gr.Textbox(label="System Prompt", lines=3, value=update_default_prompt(search_chk.value))

-            # GPU Time Estimate Display
-            duration_display = gr.Markdown(value=update_duration_estimate(
-                "Qwen3-1.7B", False, 4, 50, 1024, 5.0
-            ))
+            duration_display = gr.Markdown(value=update_duration_estimate("Qwen3-1.7B", False, 4, 50, 1024, 5.0))

             gr.Markdown("### Generation Parameters")
             max_tok = gr.Slider(64, 16384, value=1024, step=32, label="Max Tokens")
@@ -641,58 +625,75 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:

     # Group all inputs for cleaner event handling
     chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st]
-
-    def start_generation_and_update_ui(*args):
-        # Update UI to "generating" state
+    # Group all UI components that change state
+    interactive_components = [txt, submit_btn, cancel_btn, chat, dbg]
+
+    def submit_and_manage_ui(user_msg, chat_history, *args):
+        """
+        An orchestrator function that manages the UI state and calls the backend chat function.
+        It uses a try...finally block to ensure the UI is always reset.
+        """
+        # Immediately update UI to a "generating" state
         yield {
+            # Add the user's message to the chat and a placeholder for the response
+            chat: chat_history + [[user_msg, None]],
+            txt: gr.update(value="", interactive=False),
             submit_btn: gr.update(interactive=False),
             cancel_btn: gr.update(visible=True),
-            txt: gr.update(interactive=False, value=""),  # Clear textbox and disable
         }
-        # Call the actual chat response generator
-        for output in chat_response(*args):
+
+        try:
+            # Package the arguments for the backend function
+            backend_args = [user_msg, chat_history] + list(args)
+            # Stream the response from the backend
+            for response_chunk in chat_response(*backend_args):
+                yield {
+                    chat: response_chunk[0],
+                    dbg: response_chunk[1],
+                }
+        except Exception as e:
+            print(f"An error occurred during generation: {e}")
+        finally:
+            # Always reset the UI to an "idle" state, regardless of completion or cancellation
+            print("Resetting UI state.")
             yield {
-                chat: output[0],
-                dbg: output[1]
+                txt: gr.update(interactive=True),
+                submit_btn: gr.update(interactive=True),
+                cancel_btn: gr.update(visible=False),
             }

-    def reset_ui_after_generation():
-        # Update UI back to "idle" state
-        return {
-            submit_btn: gr.update(interactive=True),
-            cancel_btn: gr.update(visible=False),
-            txt: gr.update(interactive=True),  # Re-enable textbox
-        }
-
     def set_cancel_flag():
+        """Called by the cancel button, sets the global event."""
         cancel_event.set()
         print("Cancellation signal sent.")

-    # When the user submits their message (via button or enter)
+    # Event for submitting text via Enter key
     submit_event = txt.submit(
-        fn=start_generation_and_update_ui,
+        fn=submit_and_manage_ui,
         inputs=chat_inputs,
-        outputs=[chat, dbg, submit_btn, cancel_btn, txt]
-    ).then(fn=reset_ui_after_generation, outputs=[submit_btn, cancel_btn, txt])
-
+        outputs=interactive_components,
+    )
+
+    # Event for submitting text via the "Submit" button
     submit_btn.click(
-        fn=start_generation_and_update_ui,
+        fn=submit_and_manage_ui,
         inputs=chat_inputs,
-        outputs=[chat, dbg, submit_btn, cancel_btn, txt]
-    ).then(fn=reset_ui_after_generation, outputs=[submit_btn, cancel_btn, txt])
+        outputs=interactive_components,
+    )

-    # When the user clicks the cancel button
+    # Event for the "Cancel" button. It calls the flag-setting function
+    # and, crucially, cancels the long-running submit_event.
     cancel_btn.click(
         fn=set_cancel_flag,
-        cancels=[submit_event]  # This tells Gradio to stop the running `submit_event`
+        cancels=[submit_event]
     )
-
-    # Update duration estimate when relevant inputs change
+
+    # Listeners for updating the duration estimate
     duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st]
     for component in duration_inputs:
         component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)

-    # Other event listeners
+    # Other minor event listeners
     search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
     clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])