smirki committed
Commit 46823cb · verified · 1 Parent(s): 91b3f5a

Update app.py

Files changed (1)
  app.py  +84 -63
app.py CHANGED
@@ -1,47 +1,57 @@
 import subprocess
-subprocess.run(
-    'pip install flash-attn==2.7.0.post2 --no-build-isolation',
-    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-    shell=True
-)
-subprocess.run(
-    'pip install transformers',
-    shell=True
-)
-
-import spaces
 import os
 import re
 import logging
+import base64
 from typing import List
 from threading import Thread
-import base64

 import torch
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

+# Install packages (if needed). Adjust or remove if your environment already has them.
+subprocess.run(
+    ["pip", "install", "flash-attn==2.7.0.post2", "--no-build-isolation"]
+)
+subprocess.run(["pip", "install", "transformers"])
+
+# Optional: set up CUDA-specific environment vars if you need them
+# os.environ["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
+
 # ----------------------------------------------------------------------
-# 1. Setup Model & Tokenizer
+# 1. Setup Logging
 # ----------------------------------------------------------------------
-model_name = 'smirki/UIGEN-T1.1-Qwen-7B' # Change as needed
-use_thread = True # Generation happens in a background thread
-
 logger = logging.getLogger(__name__)
-logging.getLogger("httpx").setLevel(logging.WARNING)
 logging.basicConfig(level=logging.INFO)
+logging.getLogger("httpx").setLevel(logging.WARNING)
+
+# ----------------------------------------------------------------------
+# 2. Model & Tokenizer Initialization
+# ----------------------------------------------------------------------
+model_name = "smirki/UIGEN-T1.1-Qwen-7B" # adjust as needed

-logger.info("Loading model and tokenizer...")
+logger.info("Loading model & tokenizer...")
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
-    torch_dtype=torch.bfloat16,
+    device_map="auto", # auto-shard across available GPU(s)
+    torch_dtype=torch.bfloat16, # or torch.float16, depending on your hardware
     trust_remote_code=True
-).to("cuda")
+)
 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+# Optional speed-up via torch.compile (requires PyTorch ≥ 2.0).
+# Comment out if you run into any compatibility issues.
+try:
+    model = torch.compile(model)
+    logger.info("Model compiled with torch.compile for potential speed-up.")
+except Exception as e:
+    logger.warning(f"Could not compile model: {e}")
+
 logger.info("Model and tokenizer loaded successfully.")

 # ----------------------------------------------------------------------
-# 2. Two-Phase Prompt Templates
+# 3. Two-Phase Prompt Templates
 # ----------------------------------------------------------------------
 s1_inference_prompt_think_only = """<|im_start|>user
 {question}<|im_end|>
@@ -49,29 +59,27 @@ s1_inference_prompt_think_only = """<|im_start|>user
 <|im_start|>think
 """

-# ----------------------------------------------------------------------
-# 3. Generation Parameter Setup
-# ----------------------------------------------------------------------
 THINK_MAX_NEW_TOKENS = 12000
 ANSWER_MAX_NEW_TOKENS = 12000

 def initialize_gen_kwargs():
+    """Common generation parameters for both phases; tweak as necessary."""
     return {
-        "max_new_tokens": 1024, # default; will be overwritten per phase
+        "max_new_tokens": 1024, # will be updated for each phase
         "do_sample": True,
         "temperature": 0.7,
         "top_p": 0.9,
         "repetition_penalty": 1.05,
-        # "eos_token_id": model.generation_config.eos_token_id, # Removed to avoid premature stopping
         "pad_token_id": tokenizer.pad_token_id,
         "use_cache": True,
-        "streamer": None # dynamically added
+        "streamer": None # will be replaced with TextIteratorStreamer
     }

 # ----------------------------------------------------------------------
 # 4. Helper to submit chat
 # ----------------------------------------------------------------------
 def submit_chat(chatbot, text_input):
+    """Adds the user query to the Chatbot list, clearing the textbox."""
     if not text_input.strip():
         return chatbot, ""
     response = ""
@@ -83,6 +91,10 @@ def submit_chat(chatbot, text_input):
 # 5. Artifacts Handling
 # ----------------------------------------------------------------------
 def extract_html_code_block(text: str) -> str:
+    """
+    Extracts the first ```html ... ``` block from the model's answer.
+    If none found, returns the entire text stripped.
+    """
     pattern = r'```html\s*(.*?)\s*```'
     match = re.search(pattern, text, re.DOTALL)
     if match:
@@ -93,6 +105,10 @@ def extract_html_code_block(text: str) -> str:
     return text.strip()

 def send_to_sandbox(html_code: str) -> str:
+    """
+    Converts HTML code into a base64-encoded Data URI embedded in an iframe,
+    which can be displayed in Gradio’s HTML component.
+    """
     encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
     data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
     return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
@@ -100,18 +116,27 @@ def send_to_sandbox(html_code: str) -> str:
 # ----------------------------------------------------------------------
 # 6. The Two-Phase Streaming Inference
 # ----------------------------------------------------------------------
-@spaces.GPU
 def ovis_chat(chatbot: List[List[str]]):
+    """
+    Main two-phase pipeline:
+      1) "Think" phase (hidden chain-of-thought)
+      2) "Answer" phase
+    Yields intermediate partial results for real-time streaming in Gradio.
+    """
     logger.info("Starting two-phase generation...")
-    # Phase 1: "think" phase
-    last_query = chatbot[-1][0]
+
+    # -- Phase 1: "think" --
+    last_query = chatbot[-1][0] # latest user query
     formatted_think_prompt = s1_inference_prompt_think_only.format(question=last_query)
-    logger.info("Formatted think prompt.")
-
-    input_ids_think = tokenizer.encode(formatted_think_prompt, return_tensors="pt").to(model.device)
-    attention_mask_think = torch.ne(input_ids_think, tokenizer.pad_token_id).to(model.device)
+
+    # Prepare input
+    input_ids_think = tokenizer.encode(formatted_think_prompt, return_tensors="pt")
+    attention_mask_think = torch.ne(input_ids_think, tokenizer.pad_token_id)
+    # Move to correct device automatically if using device_map="auto"
+    # or if single GPU, you can do e.g. input_ids_think = input_ids_think.cuda()
     think_inputs = {"input_ids": input_ids_think, "attention_mask": attention_mask_think}
-
+
+    # Generation params
     gen_kwargs_think = initialize_gen_kwargs()
     gen_kwargs_think["max_new_tokens"] = THINK_MAX_NEW_TOKENS
     think_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
@@ -120,27 +145,30 @@ def ovis_chat(chatbot: List[List[str]]):
     full_think = ""
     try:
         with torch.inference_mode():
-            logger.info("Starting think phase generation thread...")
             thread_think = Thread(target=lambda: model.generate(**think_inputs, **gen_kwargs_think))
             thread_think.start()
+            # Stream partial tokens as they arrive
             for new_text in think_streamer:
                 full_think += new_text
-                logger.info(f"Think phase token: {new_text.strip()}")
+                # If you don’t need every single token logged, skip or reduce:
+                # logger.debug(f"Think token: {new_text.strip()}")
+
+                # Show partial chain-of-thought in the Chatbot’s assistant window
                 display_text = f"<|im_start|>think\n{full_think.strip()}"
                 chatbot[-1][1] = display_text
                 yield chatbot, ""
         thread_think.join()
-        logger.info("Think phase completed.")
     except Exception as e:
         logger.error("Error during think phase: " + str(e))
         yield chatbot, f"Error in think phase: {str(e)}"
         return
+    logger.info("Think phase completed.")

-    # Phase 2: "answer" phase
+    # -- Phase 2: "answer" --
     new_prompt = formatted_think_prompt + full_think.strip() + "\n<|im_start|>answer\n"
-    logger.info("Constructed prompt for answer phase.")
-    input_ids_answer = tokenizer.encode(new_prompt, return_tensors="pt").to(model.device)
-    attention_mask_answer = torch.ne(input_ids_answer, tokenizer.pad_token_id).to(model.device)
+
+    input_ids_answer = tokenizer.encode(new_prompt, return_tensors="pt")
+    attention_mask_answer = torch.ne(input_ids_answer, tokenizer.pad_token_id)
     answer_inputs = {"input_ids": input_ids_answer, "attention_mask": attention_mask_answer}

     gen_kwargs_answer = initialize_gen_kwargs()
@@ -151,12 +179,12 @@ def ovis_chat(chatbot: List[List[str]]):
     full_answer = ""
     try:
         with torch.inference_mode():
-            logger.info("Starting answer phase generation thread...")
             thread_answer = Thread(target=lambda: model.generate(**answer_inputs, **gen_kwargs_answer))
             thread_answer.start()
             for new_text in answer_streamer:
                 full_answer += new_text
-                logger.info(f"Answer phase token: {new_text.strip()}")
+                # logger.debug(f"Answer token: {new_text.strip()}")
+
                 display_text = (
                     f"<|im_start|>think\n{full_think.strip()}\n\n"
                     f"<|im_start|>answer\n{full_answer.strip()}"
@@ -164,13 +192,16 @@ def ovis_chat(chatbot: List[List[str]]):
                 chatbot[-1][1] = display_text
                 yield chatbot, ""
         thread_answer.join()
-        logger.info("Answer phase completed.")
     except Exception as e:
         logger.error("Error during answer phase: " + str(e))
         yield chatbot, f"Error in answer phase: {str(e)}"
         return
+    logger.info("Answer phase completed.")

+    # Logging the final conversation
     log_conversation(chatbot)
+
+    # Extract HTML code if any & display
     html_code = extract_html_code_block(full_answer)
     sandbox_iframe = send_to_sandbox(html_code)
     yield chatbot, sandbox_iframe
@@ -197,21 +228,18 @@ css_code = """
     justify-content: center;
     align-items: center;
 }
-
 .right_panel {
     margin-top: 16px;
     border: 1px solid #BFBFC4;
     border-radius: 8px;
     overflow: hidden;
 }
-
 .render_header {
     height: 30px;
     width: 100%;
     padding: 5px 16px;
     background-color: #f5f5f5;
 }
-
 .header_btn {
     display: inline-block;
     height: 10px;
@@ -219,18 +247,15 @@ css_code = """
     border-radius: 50%;
     margin-right: 4px;
 }
-
 .render_header > .header_btn:nth-child(1) {
     background-color: #f5222d;
 }
-
 .render_header > .header_btn:nth-child(2) {
     background-color: #faad14;
 }
 .render_header > .header_btn:nth-child(3) {
     background-color: #52c41a;
 }
-
 .right_content {
     height: 920px;
     display: flex;
@@ -238,7 +263,6 @@ css_code = """
     justify-content: center;
     align-items: center;
 }
-
 .html_content {
     width: 100%;
     height: 920px;
@@ -265,11 +289,7 @@ with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:

     with gr.Row():
         with gr.Column(scale=4):
-            chatbot = gr.Chatbot(
-                label="Chat",
-                height=520,
-                show_copy_button=True
-            )
+            chatbot = gr.Chatbot(label="Chat", height=520, show_copy_button=True)
             with gr.Row():
                 text_input = gr.Textbox(
                     label="Prompt",
@@ -280,11 +300,12 @@ with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
                 submit_btn = gr.Button("Send", variant="primary")
                 clear_btn = gr.Button("Clear", variant="secondary")
         with gr.Column(scale=6):
-            gr.HTML('<div class="render_header"><span class="header_btn"></span><span class="header_btn"></span><span class="header_btn"></span></div>')
-            artifact_html = gr.HTML(
-                value="",
-                elem_classes="html_content"
+            gr.HTML(
+                '<div class="render_header">'
+                '<span class="header_btn"></span><span class="header_btn"></span><span class="header_btn"></span>'
+                '</div>'
             )
+            artifact_html = gr.HTML(value="", elem_classes="html_content")

     submit_btn.click(
         submit_chat, [chatbot, text_input], [chatbot, text_input]
@@ -303,5 +324,5 @@ with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
         outputs=[chatbot, text_input, artifact_html]
     )

-    logger.info("Launching demo with GPU support...")
-    demo.queue(default_concurrency_limit=1).launch(server_name="0.0.0.0", share=True)
+    logger.info("Launching Gradio app. Please wait...")
+    demo.queue(concurrency_count=10).launch(server_name="0.0.0.0", share=True)